{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import sys\n",
    "\n",
    "# Third-party: plotting and the PyTorch stack\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# PyHealth: model class and the interpretability method used in this notebook\n",
    "from pyhealth.models import Transformer\n",
    "from pyhealth.interpret.methods.chefer import CheferRelevance\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: Pandarallel will run on 64 workers.\n",
      "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n",
      "finish basic patient information parsing : 1.2506768703460693s\n",
      "finish parsing DIAGNOSES_ICD : 1.6561415195465088s\n",
      "finish parsing PROCEDURES_ICD : 0.8813426494598389s\n",
      "finish parsing PRESCRIPTIONS : 7.1212241649627686s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Mapping codes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 24326.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Statistics of base dataset (dev=True):\n",
      "\t- Dataset: MIMIC3Dataset\n",
      "\t- Number of patients: 1000\n",
      "\t- Number of visits: 1054\n",
      "\t- Number of visits per patient: 1.0540\n",
      "\t- Number of events per visit in DIAGNOSES_ICD: 9.2068\n",
      "\t- Number of events per visit in PROCEDURES_ICD: 3.0380\n",
      "\t- Number of events per visit in PRESCRIPTIONS: 29.2457\n",
      "\n",
      "\n",
      "Statistics of base dataset (dev=True):\n",
      "\t- Dataset: MIMIC3Dataset\n",
      "\t- Number of patients: 1000\n",
      "\t- Number of visits: 1054\n",
      "\t- Number of visits per patient: 1.0540\n",
      "\t- Number of events per visit in DIAGNOSES_ICD: 9.2068\n",
      "\t- Number of events per visit in PROCEDURES_ICD: 3.0380\n",
      "\t- Number of events per visit in PRESCRIPTIONS: 29.2457\n",
      "\n",
      "\n",
      "dataset.patients: patient_id -> <Patient>\n",
      "\n",
      "<Patient>\n",
      "    - visits: visit_id -> <Visit> \n",
      "    - other patient-level info\n",
      "    \n",
      "    <Visit>\n",
      "        - event_list_dict: table_name -> List[Event]\n",
      "        - other visit-level info\n",
      "    \n",
      "        <Event>\n",
      "            - code: str\n",
      "            - other event-level info\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Generating samples for length_of_stay_prediction_mimic3_fn: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 24884.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Statistics of sample dataset:\n",
      "\t- Dataset: MIMIC3Dataset\n",
      "\t- Task: length_of_stay_prediction_mimic3_fn\n",
      "\t- Number of samples: 896\n",
      "\t- Number of patients: 886\n",
      "\t- Number of visits: 896\n",
      "\t- Number of visits per patient: 1.0113\n",
      "\t- conditions:\n",
      "\t\t- Number of conditions per sample: 9.7377\n",
      "\t\t- Number of unique conditions: 1699\n",
      "\t\t- Distribution of conditions (Top-10): [('4019', 324), ('41401', 167), ('25000', 151), ('4280', 147), ('5849', 121), ('2724', 118), ('42731', 110), ('51881', 89), ('53081', 88), ('2720', 85)]\n",
      "\t- procedures:\n",
      "\t\t- Number of procedures per sample: 3.5647\n",
      "\t\t- Number of unique procedures: 558\n",
      "\t\t- Distribution of procedures (Top-10): [('3893', 155), ('9671', 139), ('9604', 123), ('9904', 100), ('8856', 85), ('966', 64), ('3615', 61), ('3961', 59), ('3722', 58), ('3891', 55)]\n",
      "\t- drugs:\n",
      "\t\t- Number of drugs per sample: 32.5804\n",
      "\t\t- Number of unique drugs: 1911\n",
      "\t\t- Distribution of drugs (Top-10): [('00338004904', 388), ('00008084199', 362), ('00338004903', 317), ('00338001702', 305), ('00641040025', 285), ('00338004938', 271), ('00338004902', 269), ('63323026201', 268), ('00904516561', 252), ('00182844789', 242)]\n",
      "\t- label:\n",
      "\t\t- Number of label per sample: 1.0000\n",
      "\t\t- Number of unique label: 10\n",
      "\t\t- Distribution of label (Top-10): [(0, 360), (1, 109), (2, 94), (8, 85), (3, 63), (4, 44), (7, 43), (6, 43), (5, 42), (9, 13)]\n",
      "Statistics of sample dataset:\n",
      "\t- Dataset: MIMIC3Dataset\n",
      "\t- Task: length_of_stay_prediction_mimic3_fn\n",
      "\t- Number of samples: 896\n",
      "\t- Number of patients: 886\n",
      "\t- Number of visits: 896\n",
      "\t- Number of visits per patient: 1.0113\n",
      "\t- conditions:\n",
      "\t\t- Number of conditions per sample: 9.7377\n",
      "\t\t- Number of unique conditions: 1699\n",
      "\t\t- Distribution of conditions (Top-10): [('4019', 324), ('41401', 167), ('25000', 151), ('4280', 147), ('5849', 121), ('2724', 118), ('42731', 110), ('51881', 89), ('53081', 88), ('2720', 85)]\n",
      "\t- procedures:\n",
      "\t\t- Number of procedures per sample: 3.5647\n",
      "\t\t- Number of unique procedures: 558\n",
      "\t\t- Distribution of procedures (Top-10): [('3893', 155), ('9671', 139), ('9604', 123), ('9904', 100), ('8856', 85), ('966', 64), ('3615', 61), ('3961', 59), ('3722', 58), ('3891', 55)]\n",
      "\t- drugs:\n",
      "\t\t- Number of drugs per sample: 32.5804\n",
      "\t\t- Number of unique drugs: 1911\n",
      "\t\t- Distribution of drugs (Top-10): [('00338004904', 388), ('00008084199', 362), ('00338004903', 317), ('00338001702', 305), ('00641040025', 285), ('00338004938', 271), ('00338004902', 269), ('63323026201', 268), ('00904516561', 252), ('00182844789', 242)]\n",
      "\t- label:\n",
      "\t\t- Number of label per sample: 1.0000\n",
      "\t\t- Number of unique label: 10\n",
      "\t\t- Distribution of label (Top-10): [(0, 360), (1, 109), (2, 94), (8, 85), (3, 63), (4, 44), (7, 43), (6, 43), (5, 42), (9, 13)]\n",
      "Testing MIMIC3 STUFF\n",
      "{'visit_id': '100003', 'patient_id': '4', 'conditions': [['4019', '96501', '2851', '29281', '8208', '53230']], 'procedures': [['9904', '8604', '5114']], 'drugs': [['00487950125', '00338004903', '00074665305', '51079001920', '00338268975', '00006096328', '00338008904', '00045006701', '51079000522', '00173044202', '00456402063', '00731040106', '49884055001', '00186021003', '00049343041']], 'label': 8}\n",
      "----\n",
      "Transformer(\n",
      "  (embeddings): ModuleDict(\n",
      "    (conditions): Embedding(1701, 128, padding_idx=0)\n",
      "    (procedures): Embedding(560, 128, padding_idx=0)\n",
      "    (drugs): Embedding(1913, 128, padding_idx=0)\n",
      "  )\n",
      "  (linear_layers): ModuleDict()\n",
      "  (transformer): ModuleDict(\n",
      "    (conditions): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (procedures): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (drugs): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (fc): Linear(in_features=384, out_features=10, bias=True)\n",
      ")\n",
      "Transformer(\n",
      "  (embeddings): ModuleDict(\n",
      "    (conditions): Embedding(1701, 128, padding_idx=0)\n",
      "    (procedures): Embedding(560, 128, padding_idx=0)\n",
      "    (drugs): Embedding(1913, 128, padding_idx=0)\n",
      "  )\n",
      "  (linear_layers): ModuleDict()\n",
      "  (transformer): ModuleDict(\n",
      "    (conditions): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (procedures): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (drugs): TransformerLayer(\n",
      "      (transformer): ModuleList(\n",
      "        (0): TransformerBlock(\n",
      "          (attention): MultiHeadedAttention(\n",
      "            (linear_layers): ModuleList(\n",
      "              (0): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (1): Linear(in_features=128, out_features=128, bias=False)\n",
      "              (2): Linear(in_features=128, out_features=128, bias=False)\n",
      "            )\n",
      "            (output_linear): Linear(in_features=128, out_features=128, bias=False)\n",
      "            (attention): Attention()\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (feed_forward): PositionwiseFeedForward(\n",
      "            (w_1): Linear(in_features=128, out_features=512, bias=True)\n",
      "            (w_2): Linear(in_features=512, out_features=128, bias=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "            (activation): GELU(approximate='none')\n",
      "          )\n",
      "          (input_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (output_sublayer): SublayerConnection(\n",
      "            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.5, inplace=False)\n",
      "          )\n",
      "          (dropout): Dropout(p=0.5, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (fc): Linear(in_features=384, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: ['accuracy', 'f1_weighted']\n",
      "Device: cuda\n",
      "\n",
      "Training:\n",
      "Batch size: 64\n",
      "Optimizer: <class 'torch.optim.adamw.AdamW'>\n",
      "Optimizer params: {'lr': 0.001}\n",
      "Weight decay: 0.0\n",
      "Max grad norm: None\n",
      "Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7fbc0e1ea1c0>\n",
      "Monitor: accuracy\n",
      "Monitor criterion: max\n",
      "Epochs: 30\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b2840f2161541e9a20830eea9e6f96a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 0 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-0, step-12 ---\n",
      "loss: 4.5351\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 216.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-0, step-12 ---\n",
      "accuracy: 0.3000\n",
      "f1_weighted: 0.2274\n",
      "loss: 3.2804\n",
      "New best accuracy score (0.3000) at epoch-0, step-12\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0f5a0acf74d4b89a3a732524d5fc8cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 1 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-1, step-24 ---\n",
      "loss: 3.5789\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 146.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-1, step-24 ---\n",
      "accuracy: 0.2556\n",
      "f1_weighted: 0.2191\n",
      "loss: 3.0746\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb089a003ca34a4299df417d147b7937",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 2 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-2, step-36 ---\n",
      "loss: 2.9425\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 230.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-2, step-36 ---\n",
      "accuracy: 0.2556\n",
      "f1_weighted: 0.2077\n",
      "loss: 3.0053\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9723a0f446447d4a04c161dec85c0c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 3 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-3, step-48 ---\n",
      "loss: 2.3530\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 240.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-3, step-48 ---\n",
      "accuracy: 0.2778\n",
      "f1_weighted: 0.2140\n",
      "loss: 3.0312\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a4b59a50050429baca3d8569a588bd7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 4 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-4, step-60 ---\n",
      "loss: 2.1849\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 234.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-4, step-60 ---\n",
      "accuracy: 0.2778\n",
      "f1_weighted: 0.2194\n",
      "loss: 2.9949\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "326433ebc2fd47a2ad97afab3404fb5e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 5 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-5, step-72 ---\n",
      "loss: 1.8268\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 154.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-5, step-72 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2257\n",
      "loss: 2.9785\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7785202408ec4e3d99ee02ea95ab08c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 6 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-6, step-84 ---\n",
      "loss: 1.6957\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 132.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-6, step-84 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2191\n",
      "loss: 3.0148\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89087f96cc0d45139a6fb952cf18abe3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 7 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-7, step-96 ---\n",
      "loss: 1.4070\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 238.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-7, step-96 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2228\n",
      "loss: 3.0534\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "686d27c1bf0e403ebd11a8341f4f3295",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 8 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-8, step-108 ---\n",
      "loss: 1.3471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 238.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-8, step-108 ---\n",
      "accuracy: 0.2667\n",
      "f1_weighted: 0.2175\n",
      "loss: 2.9882\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f101bd7d0bed4f649d23b41312e29c47",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 9 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-9, step-120 ---\n",
      "loss: 1.2055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 144.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-9, step-120 ---\n",
      "accuracy: 0.3000\n",
      "f1_weighted: 0.2201\n",
      "loss: 3.0590\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7147dc1a76834adeb636184378b81b86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 10 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-10, step-132 ---\n",
      "loss: 1.0639\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 140.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-10, step-132 ---\n",
      "accuracy: 0.3222\n",
      "f1_weighted: 0.2504\n",
      "loss: 3.1292\n",
      "New best accuracy score (0.3222) at epoch-10, step-132\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15049a66d593425487f35090dc3808be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 11 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-11, step-144 ---\n",
      "loss: 0.9251\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 237.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-11, step-144 ---\n",
      "accuracy: 0.3333\n",
      "f1_weighted: 0.2600\n",
      "loss: 3.2008\n",
      "New best accuracy score (0.3333) at epoch-11, step-144\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb2bfaa7482b47aaa9b5583275efc388",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 12 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-12, step-156 ---\n",
      "loss: 0.7912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 238.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-12, step-156 ---\n",
      "accuracy: 0.3444\n",
      "f1_weighted: 0.2845\n",
      "loss: 3.2917\n",
      "New best accuracy score (0.3444) at epoch-12, step-156\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f62b7c1f53ab47289e6e2eedb1bc2324",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 13 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-13, step-168 ---\n",
      "loss: 0.9224\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 240.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-13, step-168 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2432\n",
      "loss: 3.3870\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54c3fec8f74d4e1b8060c441cfca54f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 14 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-14, step-180 ---\n",
      "loss: 0.6438\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 237.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-14, step-180 ---\n",
      "accuracy: 0.3000\n",
      "f1_weighted: 0.2468\n",
      "loss: 3.4254\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4741619b7cc43a493dcb8244829fd58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 15 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-15, step-192 ---\n",
      "loss: 0.7075\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 237.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-15, step-192 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2467\n",
      "loss: 3.5565\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fbf2aab680894cd0adc5b619d0d6c86c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 16 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-16, step-204 ---\n",
      "loss: 0.6729\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 236.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-16, step-204 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2240\n",
      "loss: 3.6905\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e52b2585bb84194b722b1c8ff856bba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 17 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-17, step-216 ---\n",
      "loss: 0.5472\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 142.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-17, step-216 ---\n",
      "accuracy: 0.2667\n",
      "f1_weighted: 0.2157\n",
      "loss: 3.7475\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f5a3c03d7344807938c654af5c509fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 18 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-18, step-228 ---\n",
      "loss: 0.4842\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 134.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-18, step-228 ---\n",
      "accuracy: 0.2444\n",
      "f1_weighted: 0.1998\n",
      "loss: 3.8161\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b881cb783b6349349d5bad1f91a59d2f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 19 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-19, step-240 ---\n",
      "loss: 0.4673\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 141.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-19, step-240 ---\n",
      "accuracy: 0.2444\n",
      "f1_weighted: 0.2001\n",
      "loss: 3.8863\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f587c0472da54b7db133023613114e08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 20 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-20, step-252 ---\n",
      "loss: 0.4145\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 236.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-20, step-252 ---\n",
      "accuracy: 0.2667\n",
      "f1_weighted: 0.2293\n",
      "loss: 3.9445\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "53d34c5202474dba919014576d053821",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 21 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-21, step-264 ---\n",
      "loss: 0.3754\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 236.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-21, step-264 ---\n",
      "accuracy: 0.3000\n",
      "f1_weighted: 0.2449\n",
      "loss: 4.0580\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bf31a5bdc0054ca78ce4f06ba4dc0cb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 22 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-22, step-276 ---\n",
      "loss: 0.3107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 178.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-22, step-276 ---\n",
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2495\n",
      "loss: 4.0857\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6688386a6a894abc8500c335cf2c8068",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 23 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-23, step-288 ---\n",
      "loss: 0.3031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 235.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-23, step-288 ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy: 0.2889\n",
      "f1_weighted: 0.2426\n",
      "loss: 4.2967\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7cab1849dd4749359203f155205bc6d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 24 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-24, step-300 ---\n",
      "loss: 0.2953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 240.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-24, step-300 ---\n",
      "accuracy: 0.2556\n",
      "f1_weighted: 0.2128\n",
      "loss: 4.3623\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "273eefe5ecb04339ab22d83c765b992b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 25 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-25, step-312 ---\n",
      "loss: 0.2553\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 129.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-25, step-312 ---\n",
      "accuracy: 0.2556\n",
      "f1_weighted: 0.2172\n",
      "loss: 4.4135\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14bbd9386dc947f29ecc244dc86ecf29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 26 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-26, step-324 ---\n",
      "loss: 0.2691\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 194.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-26, step-324 ---\n",
      "accuracy: 0.2556\n",
      "f1_weighted: 0.2107\n",
      "loss: 4.5224\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1446f2145bf24dff8aa3ca7b2352c9c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 27 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-27, step-336 ---\n",
      "loss: 0.2153\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 183.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-27, step-336 ---\n",
      "accuracy: 0.2444\n",
      "f1_weighted: 0.2200\n",
      "loss: 4.6278\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e20460cdf07749118c971f9b4bb01086",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 28 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-28, step-348 ---\n",
      "loss: 0.2328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 235.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-28, step-348 ---\n",
      "accuracy: 0.2444\n",
      "f1_weighted: 0.2081\n",
      "loss: 4.7121\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d8ca943b9284c1d844a0d1c9948b751",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 29 / 30:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-29, step-360 ---\n",
      "loss: 0.1637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 233.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-29, step-360 ---\n",
      "accuracy: 0.2333\n",
      "f1_weighted: 0.1972\n",
      "loss: 4.7489\n",
      "Loaded best model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'visit_id': ['100003'], 'patient_id': ['4'], 'conditions': [[['4019', '96501', '2851', '29281', '8208', '53230']]], 'procedures': [[['9904', '8604', '5114']]], 'drugs': [[['00487950125', '00338004903', '00074665305', '51079001920', '00338268975', '00006096328', '00338008904', '00045006701', '51079000522', '00173044202', '00456402063', '00731040106', '49884055001', '00186021003', '00049343041']]], 'label': [8]}\n",
      "{'conditions': tensor([[1.]], device='cuda:0'), 'procedures': tensor([[1.0881]], device='cuda:0'), 'drugs': tensor([[1.]], device='cuda:0')}\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.datasets import MIMIC3Dataset\n",
    "\n",
    "mimic3_ds = MIMIC3Dataset(\n",
    "        root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n",
    "        tables=[\"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\"],\n",
    "        dev=True,\n",
    ")\n",
    "\n",
    "print (mimic3_ds.stat())\n",
    "# data format\n",
    "mimic3_ds.info()\n",
    "from pyhealth.tasks import length_of_stay_prediction_mimic3_fn\n",
    "\n",
    "mimic3_ds = mimic3_ds.set_task(task_fn=length_of_stay_prediction_mimic3_fn)\n",
    "# stats info\n",
    "print (mimic3_ds.stat())\n",
    "\n",
    "\n",
    "{\n",
    "    \"patient_id\": \"p001\",\n",
    "    \"visit_id\": \"v001\",\n",
    "    \"diagnoses\": [...],\n",
    "    \"labs\": [...],\n",
    "    \"procedures\": [...],\n",
    "    \"label\": 1,\n",
    "}\n",
    "\n",
    "from pyhealth.datasets.splitter import split_by_patient\n",
    "from pyhealth.datasets import split_by_patient, get_dataloader\n",
    "\n",
    "# data split\n",
    "train_dataset, val_dataset, test_dataset = split_by_patient(mimic3_ds, [0.8, 0.1, 0.1])\n",
    "\n",
    "# create dataloaders (they are <torch.data.DataLoader> object)\n",
    "train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)\n",
    "val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)\n",
    "test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False)\n",
    "mimic3_ds.samples[0].keys()\n",
    "\n",
    "from pyhealth.models import Transformer\n",
    "model = Transformer(\n",
    "        dataset=mimic3_ds,\n",
    "        # look up what are available for \"feature_keys\" and \"label_keys\" in dataset.samples[0]\n",
    "        feature_keys=[\"conditions\", \"procedures\", \"drugs\"],\n",
    "        label_key=\"label\",\n",
    "        mode=\"multiclass\",\n",
    "    )\n",
    "\n",
    "print(\"Testing MIMIC3 STUFF\")\n",
    "sample = test_loader.dataset[0]\n",
    "\n",
    "print(sample)\n",
    "\n",
    "print(\"----\")\n",
    "print(model)\n",
    "# exit(0)\n",
    "from pyhealth.trainer import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    metrics=[\"accuracy\", \"f1_weighted\"], # the metrics that we want to log\n",
    "    )\n",
    "\n",
    "trainer.train(\n",
    "    train_dataloader=train_loader,\n",
    "    val_dataloader=val_loader,\n",
    "    epochs=30,\n",
    "    monitor=\"accuracy\",\n",
    "    monitor_criterion=\"max\",optimizer_class=torch.optim.AdamW\n",
    ")\n",
    "data_iterator = iter(test_loader)\n",
    "data = next(data_iterator)\n",
    "print(data)\n",
    "model(**data)\n",
    "\n",
    "relevance = CheferRelevance(model)\n",
    "# returns a list ofr now\n",
    "# interpretability code here!\n",
    "data['class_index'] = data['label']\n",
    "rel_scores = relevance.get_relevance_matrix(**data)\n",
    "\n",
    "# weigh and plot these scores and their corresponding feature list\n",
    "print(rel_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot The Relevance of Each Feature\n",
    "\n",
    "\n",
    "#### Note that normally, there'd be multiple tokens and we would use the tokenizer to visualize explicitly which codes mattered to the prediction, but for now, this is what we get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.0, 1.0, 1.0881098508834839]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 3 artists>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAX2UlEQVR4nO3debxfdX3n8dc7QWQXaOIGhCCNtUgFNYILVVoQAadgFRdGRgUqM/MQxxmVPqhbLS6j0oftOKI1KkZxQVzQDEYZiyxuLGGLEMRmACUUJCCLSKsEP/PHORd+/HKXHyHn3iTn9Xw87iNn+Z7f+dzfub+8f2f7nlQVkqT+mjXTBUiSZpZBIEk9ZxBIUs8ZBJLUcwaBJPWcQSBJPWcQSCNKcl6Sv5ruZaWuGQTqpSQ3JDlwpusYk+SNSa5PcneSZUn2m+ma1B8GgTTDkuwLfAA4AngM8GngzCSzZ7Qw9YZBILWS7JDkrCSrk9zRDu881Gz3JBe339y/mWTHgeWfneRHSe5McmWS/Udc9Xzg6qq6tJpb/T8HzAEeux5+LWlKBoH0oFnAZ4BdgXnAvwEfHWrzGuAY4AnAGuAjAEl2Ar4FvBfYEXgr8LUkc4dXkmReGxbz2knfBmYn2bfdCzgGuAK4Zb3+dtIENpvpAqQNRVXdDnxtbDzJ+4Bzh5qdVlVXtfPfCVyR5LXAUcDSqlratvtukmXAocBnh9bzC2D7gUm/btf7AyDAncAhZUdgmibuEUitJFsl+USSnye5G7gA2H7oWP2NA8M/Bx5FcxhnV+Dl7Tf9O5PcCexHs+cwlWOBo4GnApvThMpZSZ74iH8paQQGgfSgtwB/BOxbVdsBz2+nZ6DNLgPD84D7gNtoAuK0qtp+4GfrqvrACOvdGzirqn5WVb+vqu8ANwPPfYS/jzQSg0B99qgkW4z9ADvQnBe4sz0J/LfjLHNUkj2SbAWcBHy1qu4HPg/8RZIXJZndvub+45xsHs8lwIuTPCmNFwJPBq5aL7+lNAWDQH22lOY//rGf7YEtab7hXwh8Z5xlTgMW05zI3QL4bwBVdSNwOPA2YDXNHsIJjPMZa08W3zNwsvhzwOnAecDdNCeg/3NV/fSR/4rS1OL5KEnqN/cIJKnnDAJJ6jmDQJJ6ziCQpJ7b6O4snjNnTs2fP3+my5Ckjcqll156W1Wt1eUJbIRBMH/+fJYtWzbTZUjSRiXJzyea56EhSeo5g0CSes4gkKSeMwgkqecMAknqOYNAknrOIJCknjMIJKnnDAJJ6rmN7s5iSRu2+Sd+a6ZL2GTd8IEXd/K67hFIUs8ZBJLUcwaBJPWcQSBJPWcQSFLPGQSS1HMGgST1nEEgST1nEEhSzxkEktRzBoEk9ZxBIEk911kQJDk1ya1JrppgfpJ8JMnKJMuTPKOrWiRJE+tyj2AxcPAk8w8BFrQ/xwEf77AWSdIEOguCqroA+NUkTQ4HPleNC4Htkzyhq3okSeObyecR7ATcODC+qp1283DDJMfR7DUwb968dV6h/aR3p6t+0t1m3elqm2njs1GcLK6qRVW1sKoWzp07d6bLkaRNykwGwU3ALgPjO7fTJEnTaCaDYAnwmvbqoWcDd1XVWoeFJEnd6uwcQZIvAfsDc5KsAv4WeBRAVf0TsBQ4FFgJ3Asc3VUtkqSJdRYEVXXkFPMLeENX65ckjWajOFksSeqOQSBJPWcQSFLPGQSS1HMGgST1nEEgST1nEEhSzxkEktRzBoEk9ZxBIEk9ZxBIUs8ZBJLUcwaBJPWcQSBJPWcQSFLPGQSS1HMGgST1nEEgST1nEEhSzxkEktRzBoEk9ZxBIEk9ZxBIUs8ZBJLUcwaBJPWcQSBJPWcQSFLPGQSS1HMGgST1XKdBkOTgJNcmWZnkxHHmz0tybpLLkyxPcmiX9UiS1tZZECSZDZwCHALsARyZZI+hZu8AzqiqpwOvAj7WVT2SpPF1uUewD7Cyqq6rqt8BpwOH
D7UpYLt2+DHAv3ZYjyRpHJt1+No7ATcOjK8C9h1q827g/yZ5I7A1cGCH9UiSxjHTJ4uPBBZX1c7AocBpSdaqKclxSZYlWbZ69eppL1KSNmVdBsFNwC4D4zu30wYdC5wBUFU/BrYA5gy/UFUtqqqFVbVw7ty5HZUrSf3UZRBcAixIsluSzWlOBi8ZavML4ACAJH9MEwR+5ZekadRZEFTVGuB44GzgGpqrg65OclKSw9pmbwFen+RK4EvA66qquqpJkrS2Lk8WU1VLgaVD0941MLwCeF6XNUiSJjfTJ4slSTPMIJCknjMIJKnnDAJJ6jmDQJJ6ziCQpJ4zCCSp5wwCSeo5g0CSes4gkKSeMwgkqecMAknqOYNAknrOIJCknjMIJKnnDAJJ6rmRgiDJk5Ock+SqdvxpSd7RbWmSpOkw6h7BJ4G/Ae4DqKrlNM8gliRt5EYNgq2q6uKhaWvWdzGSpOk3ahDclmR3oACSHAHc3FlVkqRpM+rD698ALAKekuQm4HrgqM6qkiRNm5GCoKquAw5MsjUwq6p+3W1ZkqTpMupVQ+9Psn1V/aaqfp1khyTv7bo4SVL3Rj1HcEhV3Tk2UlV3AId2UpEkaVqNGgSzkzx6bCTJlsCjJ2kvSdpIjHqy+AvAOUk+044fDXy2m5IkSdNp1JPFH0yyHDignfSeqjq7u7IkSdNl1D0CqurbwLc7rEWSNANGvWropUn+JcldSe5O8uskd3ddnCSpe6OeLP4QcFhVPaaqtquqbatqu6kWSnJwkmuTrExy4gRtXpFkRZKrk3zx4RQvSXrkRj009MuquubhvHCS2cApwAuBVcAlSZZU1YqBNgtoOrN7XlXdkeSxD2cdkqRHbtQgWJbky8A3gN+OTayqr0+yzD7AyvauZJKcDhwOrBho83rglPa+BKrq1tFLlyStD6MGwXbAvcBBA9MKmCwIdgJuHBhfBew71ObJAEl+CMwG3l1V3xmxJknSejDq5aNHd7j+BcD+wM7ABUn+ZPAuZoAkxwHHAcybN6+jUiSpn0YKgiRbAMcCTwW2GJteVcdMsthNwC4D4zu30watAi6qqvuA65P8jCYYLhlsVFWLaHo/ZeHChTVKzZKk0Yx61dBpwOOBFwHn0/ynPlUPpJcAC5LslmRzmieaLRlq8w2avQGSzKE5VHTdiDVJktaDUYPgD6vqncBvquqzwItZ+3j/Q1TVGuB44GzgGuCMqro6yUlJDmubnQ3cnmQFcC5wQlXdvi6/iCRp3Yx6svi+9t87k+wJ3AJMealnVS0Flg5Ne9fAcAFvbn8kSTNg1CBYlGQH4J00h3e2aYclSRu5UYPgM1V1P835gSd1WI8kaZqNeo7g+iSLkhyQJJ1WJEmaVqMGwVOAf6Z5iP0NST6aZL/uypIkTZeRgqCq7q2qM6rqpcDeNHcan99lYZKk6THqHgFJXpDkY8ClNDeVvaKzqiRJ02bUO4tvAC4HzqC51v83XRYlSZo+o1419LSq8kE0krQJGvXQ0OOTnJPkKoAkT0vyjg7rkiRNk1GD4JM0D5C5D6CqltP0HSRJ2siNGgRbVdXFQ9PWrO9iJEnTb9QguC3J7jQPoyHJEcDNnVUlSZo2o54sfgPN8wCekuQm4HrgqM6qkiRNm1GfUHYdcGCSrYFZVTXVswgkSRuJSYMgybjdQ491N1RVH+6gJknSNJpqj2DbaalCkjRjJg2Cqvq76SpEkjQzRrpqKMmTvaFMkjZN3lAmST3nDWWS1HPeUCZJPfdIbih7dWdVSZKmzTrdUAbcS3OO4Ocd1iZJmgaTHhpKsl2Sv2mfUfxCmgB4LbASn1AmSZuEqfYITgPuAH4MvB54OxDgL6vqim5LkyRNh6mC4ElV9ScAST5Fc4J4XlX9e+eVSZKmxVRXDd03NlBV9wOrDAFJ2rRMtUewV5KxZxUH2LIdD1BVtV2n1UmSOjdVX0Ozp6sQSdLMGPWGMknSJqrTIEhycJJrk6xMcuIk7V6WpJIs7LIe
SdLaOguCJLOBU4BDgD2AI5PsMU67bYE3ARd1VYskaWJd7hHsA6ysquuq6nfA6cDh47R7D/BBwKuRJGkGdBkEOwE3Doyvaqc9IMkzgF2q6luTvVCS45IsS7Js9erV679SSeqxGTtZnGQW8GHgLVO1rapFVbWwqhbOnTu3++IkqUe6DIKbgF0Gxndup43ZFtgTOC/JDcCzgSWeMJak6dVlEFwCLEiyW5LNaXorXTI2s6ruqqo5VTW/quYDFwKHVdWyDmuSJA3pLAiqag1wPHA2cA1wRlVdneSkJId1tV5J0sMz6oNp1klVLQWWDk171wRt9++yFknS+LyzWJJ6ziCQpJ4zCCSp5wwCSeo5g0CSes4gkKSeMwgkqecMAknqOYNAknrOIJCknjMIJKnnDAJJ6jmDQJJ6ziCQpJ4zCCSp5wwCSeo5g0CSes4gkKSeMwgkqecMAknqOYNAknrOIJCknjMIJKnnDAJJ6jmDQJJ6ziCQpJ4zCCSp5wwCSeo5g0CSeq7TIEhycJJrk6xMcuI489+cZEWS5UnOSbJrl/VIktbWWRAkmQ2cAhwC7AEcmWSPoWaXAwur6mnAV4EPdVWPJGl8Xe4R7AOsrKrrqup3wOnA4YMNqurcqrq3Hb0Q2LnDeiRJ4+gyCHYCbhwYX9VOm8ixwLfHm5HkuCTLkixbvXr1eixRkrRBnCxOchSwEDh5vPlVtaiqFlbVwrlz505vcZK0idusw9e+CdhlYHzndtpDJDkQeDvwgqr6bYf1SJLG0eUewSXAgiS7JdkceBWwZLBBkqcDnwAOq6pbO6xFkjSBzoKgqtYAxwNnA9cAZ1TV1UlOSnJY2+xkYBvgK0muSLJkgpeTJHWky0NDVNVSYOnQtHcNDB/Y5folSVPbIE4WS5JmjkEgST1nEEhSzxkEktRzBoEk9ZxBIEk9ZxBIUs8ZBJLUcwaBJPWcQSBJPWcQSFLPGQSS1HMGgST1nEEgST1nEEhSzxkEktRzBoEk9ZxBIEk9ZxBIUs8ZBJLUcwaBJPWcQSBJPWcQSFLPGQSS1HMGgST1nEEgST1nEEhSzxkEktRzBoEk9VynQZDk4CTXJlmZ5MRx5j86yZfb+Rclmd9lPZKktXUWBElmA6cAhwB7AEcm2WOo2bHAHVX1h8A/AB/sqh5J0vi63CPYB1hZVddV1e+A04HDh9ocDny2Hf4qcECSdFiTJGnIZh2+9k7AjQPjq4B9J2pTVWuS3AX8AXDbYKMkxwHHtaP3JLm2k4o3PHMYei82VHFfDjai7QVus1afttmuE83oMgjWm6paBCya6TqmW5JlVbVwpuvQaNxeGx+3WaPLQ0M3AbsMjO/cThu3TZLNgMcAt3dYkyRpSJdBcAmwIMluSTYHXgUsGWqzBHhtO3wE8L2qqg5rkiQN6ezQUHvM/3jgbGA2cGpVXZ3kJGBZVS0BPg2clmQl8CuasNCDenc4bCPn9tr4uM2A+AVckvrNO4slqecMAknqOYNgA5JkcZIj2uFPjd2JneRtQ+1+NBP19VmSdyd560zXoW4lOS9J7y4nNQg2UFX1V1W1oh1929C8585ASRrSXvKsGdB2YTPjNpW/AYPgEUjymiTLk1yZ5LQk85N8r512TpJ5bbvFST6S5EdJrhv41p8kH2075vtn4LEDr31ekoVJPgBsmeSKJF9o590zsPzJSa5K8pMkr2yn798u/9UkP03yhbGuO5J8IMmKtsa/n953bOOS5O1JfpbkB8AftdPOS/KPSZYBbxrci2vnj22bWUk+1r7/302ydGC7uw0m0X6Oxv5ur2n/jrdKckOSDya5DHh5kiPbv/urkgfvuW07u7ys/Vye007bOsmpSS5OcnmSw9vpWyY5vV3PmcCWA69zz8DwEUkWt8OLk/xTkouADyXZPcl3klya5PtJntK2e3lb25VJLpiO926dVZU/6/ADPBX4GTCnHd8R+D/Aa9vxY4BvtMOLga/QBO8eNH0wAbwU+C7N5bVPBO4Ejmjn
nQcsbIfvGVr3Pe2/LxtY/nHAL4AnAPsDd9HcxDcL+DGwH033Hdfy4NVi28/0+7ih/gDPBH4CbAVsB6wE3tpul48NtFs8ts2Gts0RwNL2/X88cEc7zW0w9Xs/Hyjgee34qe17fwPw1+20J7Z/73NpLoP/HvCSdvxGYLe23Y7tv+8Hjhp7z9vP7tbAm2kubQd4GrBmvM9du+0WD2zzs4DZ7fg5wIJ2eF+a+6Fo/3522hi2s3sE6+7Pga9U1W0AVfUr4DnAF9v5p9H85zvmG1X1+2oO9zyunfZ84EtVdX9V/SvNH/PDsd/A8r8Ezgee1c67uKpWVdXvgStoPlx3Af8OfDrJS4F7H+b6+uRPgTOr6t6qupuH3gz55RGW34/m7+P3VXULcG473W0wmhur6oft8Od58LM09t4/CzivqlZX1RrgCzSfp2cDF1TV9fDA5xLgIODEJFfQhPkWwLx2mc+3bZcDy0es7ytVdX+SbYDnAl9pX/sTNF/GAH4ILE7yepovaxusTeL41kbitwPD09HD6uD67gc2q+Ymv32AA2i+4RxPE2h6eH4zMLyG9hBrklnA5pMt6DYY2fANTmPjvxluOKIAL6uqh3RYmck7Ox6sYYuheWN1zALurKq911q46r8k2Rd4MXBpkmdW1QbZhY57BOvuezTHKf8AIMmOwI948O7oVwPfn+I1LgBemWR2kicAfzZBu/uSPGqc6d8fWH4uzbebiydaWfvt5TFVtRT4H8BeU9TXZxcAL2mPIW8L/MUE7W6gOYwEcBgwtp1+CLysPVfwOJrDdW6D0c1L8px2+D8CPxiafzHwgiRz0pw4PpJmj/hC4PlJdoMHPpfQ9HDwxoFzZU9vp1/Qvj5J9qQ5PDTml0n+uA34vxyvyHZv8fokL29fI0n2aod3r6qLqupdwGoe2vfaBsU9gnVUTXcZ7wPOT3I/cDnwRuAzSU6g2fBHT/EyZ9J8G1xBc7zzxxO0WwQsT3JZVb16aPnnAFfSfHv566q6Zexk1Ti2Bb6ZZAuab0hvnur37KuquizJl2ne21tp+s4azydp3tMrge/w4DfFr9F8619Bc8z6MprDQm6D0VwLvCHJqTTv4cdpPl8AVNXNaZ56eC7N+/itqvomPNBt/dfb/8BvBV4IvAf4R5rP0SzgeuA/tK/7mSTXANcAlw7UcCLNuYDVwDJgmwlqfTXw8STvoPkicDrN383JSRa09Z3TTtsg2cWE1JEk21TVPe1e48U0Jz9vmem6NnRpHll7VlXtOdO19IV7BFJ3zkqyPc15g/cYAtpQuUcgST3nyWJJ6jmDQJJ6ziCQpJ4zCLRJSXJ/mn6Zxn7mr8NrvCRtz6/rW9uPzlUD469v+6jZoYv1SaPwqiFtav5tvLs8H6aX0Fw/vmKKdg9Islnb1cHIkvwnmmvj/7yq7nhYFUrrkXsE2uQleWaS89tv3me3d3GPfRu/pO0d8mtperh8Ls0dwie3exS7Z6CP+vZO1hva4dclWZLke8A5maCHywlqegXNDUsHjfVXleSEtp7lSf6unXZSkv8+sNz7krwpyROSXNDWeFWSP+3kzVMvGATa1Ix12X1FkjPbrjn+N00Poc+k6cnyfW3br1fVs6pqL5q7So+tqh/RdDB3QlXtXVX/b4r1PaN97RcAb6fpeXIfmu5CTk6y9TjL7Ap8lCYEbgFIchCwANgH2Bt4ZpLnt/W+pm0zi6YLk8/TdItwdrv3sxdNx4LSOvHQkDY1Dzk01PYfsyfw3babmdnAze3sPZO8l6Zb4m1o+qN5uL471MPlYXnwSWZjPVxeM7TMauBXwCuAfxhY9iCarkpo61lQVRckub3tG+dxwOVVdXuSS4BT26D7RlVdsQ61S4BBoE1fgKur6jnjzFsMvKSqrkzyOtqO4cbxQA+jTNwL5di61urhchz3AocC309ya1V9oV32f1bVJ8Zp/yngdTTPNTgVoA2I59P0bLk4yYer6nNTrFcal4eGtKm7Fpg71pNlkkcl
eWo7b1vg5vZb9WBnfr9u5425gQd7GD2CiU3Uw+VaqupW4GDg/Ule1C57TNs7KUl2SjL2xLoz27bPatuRZFfgl1X1SZqgeMYkdUmTMgi0Sauq39H85/3BtofQK2geJALwTuAimi6jfzqw2OnACe0J392Bvwf+a5LLgTmTrO49NL1PLk9ydTs+WW3X05yYPpXm6XRfBH6c5CfAV2nDqP0dzgXOqKr728X3B65sa3ol8L8mfSOkSdjXkLSBa08SXwa8vKr+Zabr0abHPQJpA9be2LYSOMcQUFfcI5CknnOPQJJ6ziCQpJ4zCCSp5wwCSeo5g0CSeu7/Az4Iq4lxMu1hAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Bar chart of the relevance score attributed to each feature key.\n",
    "# Keys are visited in sorted order so the bar order is stable across runs.\n",
    "barCategories = sorted(rel_scores.keys())\n",
    "\n",
    "# Flatten each relevance tensor and convert it to a plain Python float.\n",
    "# NOTE(review): float() only succeeds on a single-element tensor -- confirm\n",
    "# that every value in rel_scores holds exactly one element.\n",
    "barPlot = [\n",
    "    float(rel_scores[feature].view(-1).detach().cpu())\n",
    "    for feature in barCategories\n",
    "]\n",
    "\n",
    "print(barPlot)\n",
    "\n",
    "# Label the axes and put the sample's label in the title, then draw the bars.\n",
    "plt.title(\"Label:\" + str(sample['label']))\n",
    "plt.ylabel(\"Relevance\")\n",
    "plt.xlabel(\"Feature Keys\")\n",
    "plt.bar(barCategories, barPlot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
